from scipy.interpolate import griddata
def interpolate_image_single_channel(image, missing_mask, method, fill_value=np.nan):
    """Fill the missing pixels of a single-channel image by scattered interpolation.

    Every pixel whose ``missing_mask`` entry is falsy is used as a known
    sample. ``scipy.interpolate.griddata`` then evaluates the interpolant on
    the full pixel grid, so missing pixels are estimated from the known ones.

    Args:
        image: 2-D array of pixel values (one channel).
        missing_mask: Array with the same shape as ``image``. Truthy entries
            mark pixels to be filled. Boolean or 0/1 numeric masks are accepted.
        method: ``griddata`` method name (``'nearest'``, ``'linear'`` or
            ``'cubic'``).
        fill_value: Passed to ``griddata``. It is used for grid points outside
            the convex hull of the known pixels. The default of NaN matches
            ``griddata``'s own default.

    Returns:
        Array of shape ``image.shape`` with values clipped to ``[0, 1]``.
        NaN entries (e.g. a NaN ``fill_value``) are not affected by the clip.

    Raises:
        ValueError: If ``missing_mask`` does not have ``image``'s shape, or if
            every pixel is marked missing (there is nothing to interpolate from).
    """
    image = np.asarray(image)
    rows, cols = image.shape
    # Normalise to bool so that `~` below is a logical NOT even for 0/1 int masks.
    missing = np.asarray(missing_mask, dtype=bool)
    if missing.shape != image.shape:
        raise ValueError(
            f"missing_mask shape {missing.shape} does not match image shape {image.shape}"
        )
    known = ~missing
    if not known.any():
        raise ValueError("every pixel is marked missing; nothing to interpolate from")

    # (i, j) coordinates and values of the known pixels, both in row-major
    # order -- the same order a nested row/column loop would produce.
    points = np.argwhere(known)
    values = image[known]

    grid_x, grid_y = np.mgrid[0:rows, 0:cols]
    interpolated_grid = griddata(
        points, values, (grid_x, grid_y), method=method, fill_value=fill_value
    )
    return np.clip(interpolated_grid, 0, 1)
def interpolate_image(image, missing_mask, method):
    """Fill the missing pixels of an RGB image one channel at a time.

    The first three channels along the last axis (R, G, B, in that order) are
    each passed to ``interpolate_image_single_channel`` with the same
    ``missing_mask`` and ``method``. The per-channel results are re-stacked
    along axis 2.
    """
    filled_bands = [
        interpolate_image_single_channel(image[:, :, band], missing_mask, method)
        for band in range(3)
    ]
    return np.stack(filled_bands, axis=2)
# Fill the missing pixels with each griddata method: nearest, linear and cubic.
interpolated_image_nearest = interpolate_image(image_with_nans, mask, 'nearest')
interpolated_image_linear = interpolate_image(image_with_nans, mask, 'linear')
interpolated_image_cubic = interpolate_image(image_with_nans, mask, 'cubic')

# 2x2 comparison: the SVD-imputed image (computed elsewhere) next to the
# three interpolation results, filled left-to-right, top-to-bottom.
comparison_panels = [
    (imputed_image, 'SVD Imputation'),
    (interpolated_image_nearest, 'Nearest Interpolation'),
    (interpolated_image_linear, 'Linear Interpolation'),
    (interpolated_image_cubic, 'Cubic Interpolation'),
]
plt.figure(figsize=(14, 14))
for panel_position, (panel_image, panel_title) in enumerate(comparison_panels, start=1):
    plt.subplot(2, 2, panel_position)
    plt.imshow(panel_image)
    plt.title(panel_title)
plt.tight_layout()